Skip to content

bump version to 0.6.12#3388

Merged
aleozlx merged 1 commit into
main from
bump-version-0.6.12
May 22, 2026
Merged

bump version to 0.6.12#3388
aleozlx merged 1 commit into
main from
bump-version-0.6.12

Conversation

@aleozlx
Copy link
Copy Markdown
Collaborator

@aleozlx aleozlx commented May 21, 2026

Description

Bump version to 0.6.12 for release.

Related Issues (Gated-by PRs)

https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.12

Reviewer Notes

API changes review

API changes since v0.6.11.post3, using the new `scripts/list_apis.sh` tool

diff -u \
  <(scripts/list_apis.sh -d -p --ref v0.6.11.post3) \
  <(scripts/list_apis.sh -d -p)

--- /tmp/api_baseline.txt	2026-05-21 16:07:23.252004287 -0700
+++ /tmp/api_head.txt	2026-05-21 16:07:23.316004287 -0700
@@ -251,6 +251,8 @@
     shared_expert_output: Optional[torch.Tensor] = None,
     # ===== Group quant parameters =====
     block_quant_group_size: Optional[int] = None,
+    # ===== RMSNorm variant =====
+    weight_bias: float = 0.0,
 ) -> torch.Tensor:
 [Global Functions]
 @flashinfer_api
@@ -513,6 +515,7 @@
     out_dtype: Optional[torch.dtype] = None,
     is_var_seq: bool = True,
     enable_pdl: Optional[bool] = None,
+    sinks: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 class BatchPrefillCuteDSLWrapper:
     @flashinfer_api
@@ -759,7 +762,11 @@
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     uses_shared_paged_kv_idx: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+    torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
+]:
 @flashinfer_api(trace=xqa_batch_decode_trace)
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
@@ -898,6 +905,7 @@
     weight_layout: int = WeightLayout.BlockMajorK,
     do_finalize: bool = True,
     enable_pdl: bool = True,
+    gemm1_lora_delta: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
     activation_type: int = ActivationType.Swiglu.value,
     routing_replay_out: Optional[torch.Tensor] = None,
@@ -987,6 +995,7 @@
     weight_layout: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
+    gemm1_lora_delta: Optional[torch.Tensor] = None,
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
@@ -1034,7 +1043,7 @@
 
 @flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
-    topk_ids: torch.Tensor,
+    topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: Optional[torch.Tensor],
@@ -1096,6 +1105,34 @@
     norm_topk_prob: bool = True,
     routing_replay_out: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
+
+
+@flashinfer_api
+def trtllm_mxint4_block_scale_routed_moe(
+    topk_ids: torch.Tensor,
+    hidden_states: torch.Tensor,
+    gemm1_weights: torch.Tensor,
+    gemm1_weights_scale: torch.Tensor,
+    gemm1_alpha: Optional[torch.Tensor],
+    gemm1_beta: Optional[torch.Tensor],
+    gemm1_clamp_limit: Optional[torch.Tensor],
+    gemm2_weights: torch.Tensor,
+    gemm2_weights_scale: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
+    intermediate_size: int,
+    local_expert_offset: int,
+    local_num_experts: int,
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int = 0,
+    do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
+    gemm1_lora_delta: Optional[torch.Tensor] = None,
+    output: Optional[torch.Tensor] = None,
+    tune_max_num_tokens: int = 8192,
+) -> List[torch.Tensor]:
 [Global Functions]
 @flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
@@ -1117,8 +1154,6 @@
     output_dtype: torch.dtype = torch.bfloat16,
     activation: str = "silu",
     activation_precision: str = "fp4",
-    quant_mode: Optional[str] = None,
-    source_format: str = "modelopt",
 ) -> torch.Tensor:
 class B12xMoEWrapper:
     @flashinfer_api
@@ -1136,8 +1171,6 @@
         device: str = "cuda",
         activation: str = "silu",
         activation_precision: str = "fp4",
-        quant_mode: Optional[str] = None,
-        source_format: str = "modelopt",
     ):
 
     @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
@@ -1477,8 +1510,6 @@
     out: Optional[torch.Tensor] = None,
     backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
 ):
-
-
 @flashinfer_api(trace=bmm_fp8_trace)
 def bmm_fp8(
     A: torch.Tensor,
@@ -1524,7 +1555,7 @@
     out_dtype: Optional[torch.dtype] = None,
     backend: Literal["cutlass", "trtllm"] = "cutlass",
 ):
-@flashinfer_api
+@flashinfer_api(trace=gemm_fp8_nt_groupwise_trace)
 def gemm_fp8_nt_groupwise(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1712,8 +1743,17 @@
     sf_dtype: str,
     c_dtype: str,
     sf_vec_size: int,
+    topk_weights: Optional[torch.Tensor] = None,
+    idx_src_info: Optional[torch.Tensor] = None,
+    rank_src_info: Optional[torch.Tensor] = None,
+    out_ptrs: Optional[torch.Tensor] = None,
+    num_ranks: int = 0,
     dst_signals: Optional[torch.Tensor] = None,
     sm_count: Optional[int] = None,
+    barrier_flag_local: Optional[torch.Tensor] = None,
+    barrier_flag_multicast: Optional[torch.Tensor] = None,
+    is_combine_fusion: bool = False,
+    is_swap_ab: bool = False,
     **kwargs,
 ):
 [Global Functions]
@@ -1722,14 +1762,21 @@
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
     out: torch.Tensor,
-    launch_with_pdl: bool = False,
+    launch_with_pdl: bool = True,
 ) -> None:
 @flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
     out: torch.Tensor,
-    launch_with_pdl: bool = False,
+    launch_with_pdl: bool = True,
+) -> None:
+@flashinfer_api(trace=mm_M1_16_K6144_N256_trace)
+def mm_M1_16_K6144_N256(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    out: torch.Tensor,
+    launch_with_pdl: bool = True,
 ) -> None:
 @flashinfer_api(trace=tinygemm_bf16_trace)
 def tinygemm_bf16(
@@ -1826,6 +1873,36 @@
     tactic: int = -1,
 ) -> torch.Tensor:
 [Global Functions]
+@flashinfer_api
+def checkpointing_ssu(
+    state: torch.Tensor,
+    old_x: torch.Tensor,
+    old_B: torch.Tensor,
+    old_dt: torch.Tensor,
+    old_cumAdt: torch.Tensor,
+    cache_buf_idx: torch.Tensor,
+    prev_num_accepted_tokens: torch.Tensor,
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    out: torch.Tensor,
+    D: Optional[torch.Tensor] = None,
+    z: Optional[torch.Tensor] = None,
+    dt_bias: Optional[torch.Tensor] = None,
+    dt_softplus: bool = False,
+    state_batch_indices: Optional[torch.Tensor] = None,
+    pad_slot_id: int = -1,
+    state_scale: Optional[torch.Tensor] = None,
+    rand_seed: Optional[torch.Tensor] = None,
+    philox_rounds: int = 10,
+    d_split: Optional[int] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    enable_pdl: bool = False,
+) -> torch.Tensor:
+[Global Functions]
 @flashinfer_api(trace=selective_state_update_trace)
 def selective_state_update(
     state: torch.Tensor,
@@ -1966,6 +2043,7 @@
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
 
 
@@ -1991,7 +2069,10 @@
     backend: str = "auto",
     is_var_seq: bool = True,
     uses_shared_paged_kv_idx: bool = True,
-) -> torch.Tensor:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+    cute_dsl_impl: str = "auto",
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
 
 
 @flashinfer_api(trace=xqa_batch_decode_mla_trace)
@@ -2252,6 +2333,44 @@
     norm_out: Optional[torch.Tensor] = None,
     sf_out: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    qkv,
+    q_weight,
+    k_weight,
+    **kwargs,
+):
+
+
+@flashinfer_api
+def fused_qk_rmsnorm_rope(
+    qkv: torch.Tensor,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    *,
+    ppf: int,
+    pph: int,
+    ppw: int,
+    num_frame_channels: int,
+    num_height_channels: int,
+    num_width_channels: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_dim: int,
+    eps: float = 1e-6,
+    base: float = 10000.0,
+    interleave: bool = True,
+    factor: float = 1.0,
+    low: float = 0.0,
+    high: float = 0.0,
+    attention_factor: float = 1.0,
+    is_qk_norm: bool = True,
+    output_fp8: bool = False,
+    output_quant_scale: float = 1.0,
+    v_quant_scale: float = 1.0,
+    q_out: Optional[torch.Tensor] = None,
+    k_out: Optional[torch.Tensor] = None,
+    v_out: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 [Global Functions]
 @flashinfer_api
 def get_batch_indices_positions(
@@ -2730,7 +2849,11 @@
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     uses_shared_paged_kv_idx: bool = True,
     causal: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+    torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
+]:
 
 
 @flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
@@ -2942,6 +3065,7 @@
     is_sf_swizzled_layout: bool = True,
     alignment: int = 32,
     enable_pdl: bool | None = None,
+    is_sf_8x4_layout: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:

API changes since v0.6.11.post3 (old approach)

$ git diff v0.6.11.post3..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
-def _reconstruct_value(value: Any) -> Any:
+def flush_graph_dumps(synchronize: bool = True) -> int:
+    """Write CUDA-graph-deferred level-10 dumps to disk.
+
+    When ``FLASHINFER_LOGLEVEL=10`` is active inside ``torch.cuda.graph(...)``,
+    each ``@flashinfer_api`` call records input/output tensor references instead
+    of writing immediately or inserting D2H copies into the captured graph.
+    After ``g.replay()`` completes, calling this function materializes current
+    tensor values to CPU and serializes them to two places:
+
+    1. ``inputs.pt``/``outputs.pt`` (or the safetensors equivalents) in the
+       original dump directory, for backwards compatibility. These files
+       always reflect the most recent flush.
+    2. ``graph_flushes/flush_XXXX/`` under the original dump directory. These
+       immutable snapshots preserve every explicit flush, so callers can keep
+       every replay by calling ``flush_graph_dumps()`` after every replay.
+
+    Parameters
+    ----------
+    synchronize : bool, default True
+        Synchronize the current stream first to ensure the most recent
+        ``g.replay()`` has completed before materializing tensors. Set to
+        ``False`` only if you've already synchronized externally.
+
+    Returns
+    -------
--
         routing_logits,
         None,
         None,
@@ -3199,7 +3362,7 @@ def trtllm_fp4_block_scale_moe(
 
 @flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
-    topk_ids: torch.Tensor,
+    topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: Optional[torch.Tensor],
@@ -3231,13 +3394,20 @@ def trtllm_fp4_block_scale_routed_moe(
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
 ) -> List[torch.Tensor]:
-    """FP4 block scale MoE operation.
+    """FP4 block scale MoE operation with pre-computed routing.
+
+    This function supports two pre-computed routing formats:
+    1. Packed format: topk_ids is a single tensor with packed (score << 16 | expert_id)
+    2. Unpacked format: topk_ids is a tuple of (topk_ids, topk_weights) tensors
 
     Args:
-        topk_ids (torch.Tensor): shape [seq_len, top_k]
-            Tensor of top-k indices and expert weights. Dtype must be int32.
--
         norm_topk_prob,
         routing_replay_out,
     )
+
+
+@flashinfer_api
+def trtllm_mxint4_block_scale_routed_moe(
+    topk_ids: torch.Tensor,
+    hidden_states: torch.Tensor,
+    gemm1_weights: torch.Tensor,
+    gemm1_weights_scale: torch.Tensor,
+    gemm1_alpha: Optional[torch.Tensor],
+    gemm1_beta: Optional[torch.Tensor],
+    gemm1_clamp_limit: Optional[torch.Tensor],
+    gemm2_weights: torch.Tensor,
+    gemm2_weights_scale: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
+    intermediate_size: int,
+    local_expert_offset: int,
+    local_num_experts: int,
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int = 0,
+    do_finalize: bool = True,
--
-    except Exception:
-        return False
-
-
 @supported_compute_capability([120, 121])
 @flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
@@ -74,13 +67,11 @@ def b12x_fused_moe(
     output_dtype: torch.dtype = torch.bfloat16,
     activation: str = "silu",
     activation_precision: str = "fp4",
-    quant_mode: Optional[str] = None,
-    source_format: str = "modelopt",
 ) -> torch.Tensor:
     """Run fused MoE on SM120/SM121 using b12x CuTe DSL kernels.
 
-    The kernel takes bf16 input and runs routing, FC1, activation, FC2,
-    and scatter through the selected backend.
+    The kernel takes bf16 input and fuses quantization + routing +
+    FC1 + activation + FC2 + scatter in a single launch.
     Automatically selects micro (decode), static, or dynamic backend
     based on routed row count.
 
@@ -99,19 +90,16 @@ def b12x_fused_moe(
         w1_alpha: Per-expert global scale for FC1.
         w2_alpha: Per-expert global scale for FC2.
--
 
@@ -6387,7 +6276,7 @@ def _check_gemm_fp8_nt_groupwise_problem_size(
     },
     common_check=_check_gemm_fp8_nt_groupwise_problem_size,
 )
-@flashinfer_api
+@flashinfer_api(trace=gemm_fp8_nt_groupwise_trace)
 def gemm_fp8_nt_groupwise(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -8031,7 +7920,7 @@ def _calculate_block_scale_dims(
 
 
 @functools.lru_cache(maxsize=1024)
-def create_cudnn_execution_plans_mxfp8_gemm(
+def build_cudnn_gemm_mxfp8_graph(
     a_shape,
     a_stride,
     a_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
@@ -8041,7 +7930,11 @@ def create_cudnn_execution_plans_mxfp8_gemm(
     block_size,
     o_type,  # cudnn.data_type, BF16 or FP16
     device,
+    policy=None,
 ):
+    if policy is None:
+        policy = cudnn.build_plan_policy.HEURISTICS_CHOICE
--
@@ -229,6 +264,54 @@ def mm_M1_16_K7168_N256(
     )
 
 
+@backend_requirement({}, common_check=_mm_M1_16_K6144_N256_shape_checks)
+@flashinfer_api(trace=mm_M1_16_K6144_N256_trace)
+def mm_M1_16_K6144_N256(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    out: torch.Tensor,
+    launch_with_pdl: bool = True,
+) -> None:
+    """Optimized GEMM for the router operation in GLM-MoE-DSA.
+
+    This function performs a highly optimized matrix multiplication specifically tailored
+    for the expert routing GEMM in GLM-MoE-DSA's Mixture of Experts (MoE) architecture.
+    It computes out = mat_a @ mat_b where mat_a contains token embeddings and mat_b
+    contains expert routing weights.
+
+    The implementation is optimized for the specific problem dimensions used in GLM-MoE-DSA:
+    - Hidden dimension (K): 6144
+    - Number of experts (N): 256
+    - Number of tokens (M): 1-16
+
+    Args:
+        mat_a (torch.Tensor): Input token embeddings of shape (M, K) where M is the number
--
+) -> None:
+    """Fake implementation for torch.compile() meta tensor propagation."""
+    pass
+
+
+@flashinfer_api
+def checkpointing_ssu(
+    state: torch.Tensor,
+    old_x: torch.Tensor,
+    old_B: torch.Tensor,
+    old_dt: torch.Tensor,
+    old_cumAdt: torch.Tensor,
+    cache_buf_idx: torch.Tensor,
+    prev_num_accepted_tokens: torch.Tensor,
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    out: torch.Tensor,
+    D: Optional[torch.Tensor] = None,
+    z: Optional[torch.Tensor] = None,
+    dt_bias: Optional[torch.Tensor] = None,
+    dt_softplus: bool = False,
+    state_batch_indices: Optional[torch.Tensor] = None,
+    pad_slot_id: int = -1,
--
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
     @flashinfer_api(trace=mla_paged_decode_trace)
@@ -489,6 +915,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Run the MLA attention computation.
 
@@ -506,6 +933,7 @@ class BatchMLAPagedAttentionWrapper:
             ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models.
         out : Optional[torch.Tensor]
             The output tensor, if not provided, will be allocated internally.
+            When ``o_scale`` is provided, this should be an FP8 tensor.
         lse : Optional[torch.Tensor]
             The log-sum-exp of attention logits, if not provided, will be allocated internally.
         return_lse : bool, optional
@@ -516,6 +944,10 @@ class BatchMLAPagedAttentionWrapper:
             The query length of each request, shape: ``[batch_size]``. Required when ``backend`` is ``cutlass``.
         page_table : Optional[torch.Tensor]
             The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
--
+            )
+
+    return True
+
+
+@flashinfer_api
+@backend_requirement(backend_checks={}, common_check=_check_fused_qk_rmsnorm_rope)
+def fused_qk_rmsnorm_rope(
+    qkv: torch.Tensor,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    *,
+    ppf: int,
+    pph: int,
+    ppw: int,
+    num_frame_channels: int,
+    num_height_channels: int,
+    num_width_channels: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_dim: int,
+    eps: float = 1e-6,
+    base: float = 10000.0,
+    interleave: bool = True,
    factor: float = 1.0,
```

**Supplemental: class-wrapper overload stub changes (BatchMLAPagedAttentionWrapper.run gained `o_scale`)**

```diff
$ git diff v0.6.11.post3..main -- "flashinfer/mla/_core.py" | grep -B5 -A10 "o_scale"
     mod = gen_trtllm_gen_fmha_module()
@@ -457,6 +881,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> torch.Tensor: ...
 
     @overload
@@ -473,6 +898,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
     @flashinfer_api(trace=mla_paged_decode_trace)
@@ -489,6 +915,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Run the MLA attention computation.
 
@@ -506,6 +933,7 @@ class BatchMLAPagedAttentionWrapper:
             ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models.
         out : Optional[torch.Tensor]
             The output tensor, if not provided, will be allocated internally.
+            When ``o_scale`` is provided, this should be an FP8 tensor.
         lse : Optional[torch.Tensor]
             The log-sum-exp of attention logits, if not provided, will be allocated internally.
         return_lse : bool, optional
@@ -516,6 +944,10 @@ class BatchMLAPagedAttentionWrapper:
             The query length of each request, shape: ``[batch_size]``. Required when ``backend`` is ``cutlass``.
         page_table : Optional[torch.Tensor]
             The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
+        o_scale : Optional[float]
+            FP8 output dequantization scale (``real = quantized * o_scale``).
+            When provided, ``out`` must be an FP8 tensor. Only supported with
+            the ``cutlass`` backend.
         """
         if self._backend == "cutlass":
             if return_lse:
@@ -525,7 +957,26 @@ class BatchMLAPagedAttentionWrapper:
                     "profiler_buffer does not support cutlass backend for now."
                 )
             self._cached_module = get_mla_module()
-            if out is None:
+            output_scale = 1.0
+            if o_scale is not None:
+                output_scale = float(o_scale)
+                if not math.isfinite(output_scale) or output_scale <= 0.0:
+                    raise ValueError(
+                        f"o_scale must be a finite positive value, got {o_scale}"
+                    )
+                if out is None:
+                    raise ValueError(
+                        "out tensor must be provided when o_scale is used for FP8 output."
+                    )
+                if out.dtype not in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e5m2,
+                ):
+                    raise ValueError(
+                        f"out must be an FP8 tensor when o_scale is provided, got {out.dtype}"
+                    )
+                check_shape_dtype_device(out, q_nope.shape, None, q_nope.device, "out")
+            elif out is None:
                 out = torch.empty_like(q_nope)
             else:
                 check_shape_dtype_device(
@@ -543,9 +994,14 @@ class BatchMLAPagedAttentionWrapper:
                 ckv_kpe_cache,
                 kv_len,
                 page_table,
+                output_scale,
             )
             return out
 
+        if o_scale is not None:
+            raise ValueError(
+                "o_scale is only supported with the cutlass backend for now."
+            )
         if profiler_buffer is None:
             if self._use_profiler:
                 raise ValueError(
@@ -615,7 +1071,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
     backend: str = "auto",
     is_var_seq: bool = True,
     uses_shared_paged_kv_idx: bool = True,
-) -> torch.Tensor:
+    lse: Optional[torch.Tensor] = None,
```

Supplemental: trtllm_batch_decode_with_kv_cache / trtllm_batch_context_with_kv_cache gained lse and return_lse parameters (signature widening — backwards-compatible)

$ git diff v0.6.11.post3..main -- "flashinfer/decode.py" "flashinfer/prefill.py" | grep -B3 -A6 "return_lse: bool = False"
     uses_shared_paged_kv_idx: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+    torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
+]:
     """
     Parameters
     ----------
--
     causal: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+    torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
+]:
     """
     Parameters
     ----------

Summary by CodeRabbit

  • Chores
    • Version bumped to 0.6.12.

Review Change Stack

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 21, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 48d9e980-bbb3-4257-a887-c4d117a215b7

📥 Commits

Reviewing files that changed from the base of the PR and between 41e5aa2 and d771e62.

📒 Files selected for processing (1)
  • version.txt

📝 Walkthrough

Walkthrough

This PR updates the version string in version.txt from 0.6.11 to 0.6.12, recording a new release version.

Changes

Version Bump

Layer / File(s) Summary
Version string update
version.txt
version.txt is updated to record the release version 0.6.12.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~1 minute

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • sricketts
  • yongwww
  • cyx-6
  • saltyminty
  • yyihuang
  • kahyunnam
  • nv-yunzheq

Poem

🐰 A version hops forward with hop, hop, cheer—
From point-six-eleven to twelve, crystal clear!
A bump in the code, a release so near,
The rabbit marks milestones throughout the year! 🌟

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title 'bump version to 0.6.12' directly and clearly describes the main change—a version bump from 0.6.11 to 0.6.12 as confirmed by the raw_summary.
Description check ✅ Passed The description includes the required sections: a clear description of the change ('Bump version to 0.6.12 for release'), related issues link, and comprehensive reviewer notes documenting API changes since the previous version.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch bump-version-0.6.12

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@claude
Copy link
Copy Markdown

claude Bot commented May 21, 2026

Code Review: Bump version to 0.6.12 - The diff is a single-line change (0.6.11 -> 0.6.12) and is correct. Review notes below.

@claude
Copy link
Copy Markdown

claude Bot commented May 21, 2026

Version Bump Assessment

The single-line diff (version.txt: 0.6.11 -> 0.6.12) is correct. Per the project versioning scheme, patch covers small backwards-compatible features and bug fixes. Most of the bundled changes (new APIs, additive parameter widening) fit this rule. Two items stand out as potentially breaking:

1. b12x_fused_moe removed parameters (quant_mode, source_format) — any caller passing these keyword arguments will get a TypeError at runtime. If this is part of the public API surface, this warrants a minor bump (0.7.0) or at minimum a deprecation cycle before removal.

2. create_cudnn_execution_plans_mxfp8_gemm renamed to build_cudnn_gemm_mxfp8_graph — the old name disappears entirely. The diff context suggests it is an @lru_cache internal helper, so impact may be limited — but confirming it is not re-exported publicly before tagging would be good.

If both affect downstream public users, 0.7.0 is more appropriate. If strictly internal, 0.6.12 is fine — just note them in the release changelog to avoid downstream surprises.

@claude
Copy link
Copy Markdown

claude Bot commented May 21, 2026

Observations on Bundled Changes (from PR description)

  • flush_graph_dumps(): Good addition for CUDA-graph deferred logging. Docstring and the synchronize parameter rationale are well explained.
  • BatchMLAPagedAttentionWrapper.run + o_scale: Input validation (math.isfinite, positive check, FP8 dtype check, shape/device guard against q_nope) looks solid. The cutlass-only guard for o_scale is clear.
  • trtllm_fp4_block_scale_routed_moe type widening to Union[Tensor, Tuple[Tensor,Tensor]]: Backwards-compatible.
  • trtllm_batch_decode_with_kv_cache_mla + lse/return_lse: Additive and BC.
  • gemm_fp8_nt_groupwise trace wiring: Good -- aligns with the trace checklist in CLAUDE.md.
  • fused_qk_rmsnorm_rope: New public API with @backend_requirement + @flashinfer_api. Please verify a TraceTemplate and tests/trace/example.py entry exist per the CLAUDE.md trace checklist.
  • checkpointing_ssu: Decorated with @flashinfer_api but no trace= argument visible in the diff. If this is a public API, it should carry a trace template per the same checklist.

Summary: The diff is correct. The main open question before merging is whether the b12x_fused_moe parameter removal and the cuDNN function rename are public-facing breaks (-> 0.7.0) or strictly internal (0.6.12 is fine). Also worth a quick check that checkpointing_ssu and fused_qk_rmsnorm_rope have trace templates wired up as project conventions require.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the version number in version.txt from 0.6.11 to 0.6.12. I have no feedback to provide.

@aleozlx
Copy link
Copy Markdown
Collaborator Author

aleozlx commented May 21, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !703 has been created, and the CI pipeline #52154394 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx merged commit 42e2200 into main May 22, 2026
33 checks passed
@aleozlx aleozlx deleted the bump-version-0.6.12 branch May 22, 2026 16:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants